{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.torch_basics import *\n",
    "from fastai.callback.hook import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp vision.models.unet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dynamic UNet\n",
    "\n",
    "> Unet model using PixelShuffle ICNR upsampling that can be built on top of any pretrained architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export \n",
    "def _get_sz_change_idxs(sizes):\n",
    "    \"Return the indexes `i` for which the last dimension of `sizes[i]` differs from that of `sizes[i+1]`.\"\n",
    "    # Only the trailing dimension of each shape is compared between neighbours\n",
    "    last_dims = np.array([shape[-1] for shape in sizes])\n",
    "    changed = last_dims[:-1] != last_dims[1:]\n",
    "    return list(np.flatnonzero(changed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "# Edge cases first: a single shape has no neighbour to differ from; two differing shapes change at index 0\n",
    "test_eq(_get_sz_change_idxs([[3,64,64]]), [])\n",
    "test_eq(_get_sz_change_idxs([[3,64,64], [16,32,32]]), [0])\n",
    "# Longer lists: only positions where the last dimension differs from the next entry are returned\n",
    "test_eq(_get_sz_change_idxs([[3,64,64], [16,32,32], [32,32,32], [16,32,32], [32,16,16], [16,16]]), [0,3])\n",
    "test_eq(_get_sz_change_idxs([[3,64,64], [16,64,64], [32,32,32], [16,32,32], [32,32,32], [16,16]]), [1,4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export \n",
    "class UnetBlock(Module):\n",
    "    \"A quasi-UNet block, using `PixelShuffle_ICNR` upsampling.\"\n",
    "    @delegates(ConvLayer.__init__)\n",
    "    def __init__(self, up_in_c, x_in_c, hook, final_div=True, blur=False, act_cls=defaults.activation,\n",
    "                 self_attention=False, init=nn.init.kaiming_normal_, norm_type=None, **kwargs):\n",
    "        self.hook = hook\n",
    "        # Upsampling branch: `up_in_c` channels in, half as many out\n",
    "        half_up_c = up_in_c//2\n",
    "        self.shuf = PixelShuffle_ICNR(up_in_c, half_up_c, blur=blur, act_cls=act_cls, norm_type=norm_type)\n",
    "        # The hooked activation goes through batchnorm before being concatenated in `forward`\n",
    "        self.bn = BatchNorm(x_in_c)\n",
    "        cat_c = half_up_c + x_in_c\n",
    "        out_c = cat_c if final_div else cat_c//2\n",
    "        self.conv1 = ConvLayer(cat_c, out_c, act_cls=act_cls, norm_type=norm_type, **kwargs)\n",
    "        attn = SelfAttention(out_c) if self_attention else None\n",
    "        self.conv2 = ConvLayer(out_c, out_c, act_cls=act_cls, norm_type=norm_type, xtra=attn, **kwargs)\n",
    "        self.relu = act_cls()\n",
    "        apply_init(nn.Sequential(self.conv1, self.conv2), init)\n",
    "\n",
    "    def forward(self, up_in):\n",
    "        skip = self.hook.stored\n",
    "        up_out = self.shuf(up_in)\n",
    "        skip_hw = skip.shape[-2:]\n",
    "        # Bring the upsampled tensor to the hooked activation's last two dims before concatenating\n",
    "        if up_out.shape[-2:] != skip_hw:\n",
    "            up_out = F.interpolate(up_out, skip_hw, mode='nearest')\n",
    "        merged = torch.cat([up_out, self.bn(skip)], dim=1)\n",
    "        return self.conv2(self.conv1(self.relu(merged)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class ResizeToOrig(Module):\n",
    "    \"Resize `x` to the last two dims of `x.orig` with `F.interpolate` (using `mode`) when they differ.\"\n",
    "    def __init__(self, mode='nearest'):\n",
    "        self.mode = mode\n",
    "\n",
    "    def forward(self, x):\n",
    "        orig_hw = x.orig.shape[-2:]\n",
    "        # Nothing to do when the last two dims already match those of `x.orig`\n",
    "        if x.shape[-2:] == orig_hw: return x\n",
    "        return F.interpolate(x, orig_hw, mode=self.mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export \n",
    "class DynamicUnet(SequentialEx):\n",
    "    \"Create a U-Net with `n_out` output channels from a given `encoder` architecture, traced at `img_size`.\"\n",
    "    def __init__(self, encoder, n_out, img_size, blur=False, blur_final=True, self_attention=False,\n",
    "                 y_range=None, last_cross=True, bottle=False, act_cls=defaults.activation,\n",
    "                 init=nn.init.kaiming_normal_, norm_type=None, **kwargs):\n",
    "        imsize = img_size\n",
    "        # Shapes reported by `model_sizes` for `encoder` at `imsize`; entry 1 of a shape is used below as its channel count\n",
    "        sizes = model_sizes(encoder, size=imsize)\n",
    "        # Indexes where the last dim of consecutive shapes changes, reversed so the loop below starts from the last one\n",
    "        sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))\n",
    "        # Hook the outputs of those encoder layers (detach=False); each UnetBlock reads its hook's `stored` value\n",
    "        self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)\n",
    "        # Dummy encoder output used to track channel counts while the decoder is assembled.\n",
    "        # NOTE(review): presumably this pass also fills the hooks' `stored` values read by `unet_block(x)` below,\n",
    "        # which would be why the hooks are registered first -- confirm against `dummy_eval`/`hook_outputs`\n",
    "        x = dummy_eval(encoder, imsize).detach()\n",
    "\n",
    "        ni = sizes[-1][1]\n",
    "        # Middle of the U: two ConvLayers going ni -> 2*ni -> ni, put in eval mode for the dummy pass\n",
    "        middle_conv = nn.Sequential(ConvLayer(ni, ni*2, act_cls=act_cls, norm_type=norm_type, **kwargs),\n",
    "                                    ConvLayer(ni*2, ni, act_cls=act_cls, norm_type=norm_type, **kwargs)).eval()\n",
    "        x = middle_conv(x)\n",
    "        # `layers[3]` is `middle_conv`; `apply_init` below relies on that position\n",
    "        layers = [encoder, BatchNorm(ni), nn.ReLU(), middle_conv]\n",
    "\n",
    "        # One UnetBlock per hooked layer, in the (reversed) order of `sz_chg_idxs`\n",
    "        for i,idx in enumerate(sz_chg_idxs):\n",
    "            not_final = i!=len(sz_chg_idxs)-1\n",
    "            # Channels coming up from the previous decoder stage, and channels of the hooked encoder activation\n",
    "            up_in_c, x_in_c = int(x.shape[1]), int(sizes[idx][1])\n",
    "            # The last block is blurred only when `blur_final` is also set\n",
    "            do_blur = blur and (not_final or blur_final)\n",
    "            # Self-attention, when requested, goes only in the block at position len(sz_chg_idxs)-3\n",
    "            sa = self_attention and (i==len(sz_chg_idxs)-3)\n",
    "            # The last block gets final_div=False, which makes UnetBlock halve its output channels\n",
    "            unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=do_blur, self_attention=sa,\n",
    "                                   act_cls=act_cls, init=init, norm_type=norm_type, **kwargs).eval()\n",
    "            layers.append(unet_block)\n",
    "            # Run the block on the dummy tensor so the next iteration sees its output channel count\n",
    "            x = unet_block(x)\n",
    "\n",
    "        ni = x.shape[1]\n",
    "        # Extra PixelShuffle_ICNR when the last two dims of the first recorded shape differ from `imsize`\n",
    "        if imsize != sizes[0][-2:]: layers.append(PixelShuffle_ICNR(ni, act_cls=act_cls, norm_type=norm_type))\n",
    "        # Interpolate to the last two dims of `x.orig` if needed (presumably the model input, set by `SequentialEx`)\n",
    "        layers.append(ResizeToOrig())\n",
    "        if last_cross:\n",
    "            # NOTE(review): `MergeLayer(dense=True)` presumably concatenates the original input, which is why\n",
    "            # `ni` grows by the encoder's input channel count -- confirm against `MergeLayer`\n",
    "            layers.append(MergeLayer(dense=True))\n",
    "            ni += in_channels(encoder)\n",
    "            layers.append(ResBlock(1, ni, ni//2 if bottle else ni, act_cls=act_cls, norm_type=norm_type, **kwargs))\n",
    "        # Final 1x1 ConvLayer with no activation, producing `n_out` channels\n",
    "        layers += [ConvLayer(ni, n_out, ks=1, act_cls=None, norm_type=norm_type, **kwargs)]\n",
    "        # At this point `layers[-2]` is the ResBlock when `last_cross` is set, otherwise the ResizeToOrig layer\n",
    "        apply_init(nn.Sequential(layers[3], layers[-2]), init)\n",
    "        #apply_init(nn.Sequential(layers[2]), init)\n",
    "        # Optional SigmoidRange over `y_range`, then a final ToTensorBase layer\n",
    "        if y_range is not None: layers.append(SigmoidRange(*y_range))\n",
    "        layers.append(ToTensorBase())\n",
    "        super().__init__(*layers)\n",
    "\n",
    "    def __del__(self):\n",
    "        # Remove the encoder hooks when the model is deleted; the guard covers instances where `sfs` was never set\n",
    "        if hasattr(self, \"sfs\"): self.sfs.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.vision.models import resnet34"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder: a resnet34 with its last two child modules removed\n",
    "m = nn.Sequential(*list(resnet34().children())[:-2])\n",
    "tst = DynamicUnet(m, 5, (128,128), norm_type=None)\n",
    "# A batch at the size the model was built for comes back at that size, with 5 channels\n",
    "x = cast(torch.randn(2, 3, 128, 128), TensorImage)\n",
    "y = tst(x)\n",
    "test_eq(y.shape, [2, 5, 128, 128])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# An input whose height differs from the build-time `img_size` must come back at its own size\n",
    "tst = DynamicUnet(m, 5, (128,128), norm_type=None)\n",
    "x = torch.randn(2, 3, 127, 128)\n",
    "y = tst(x)\n",
    "test_eq(y.shape, [2, 5, 127, 128])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 01a_losses.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 10b_tutorial.albumentations.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 18b_callback.preds.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.image_sequence.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.azureml.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import *\n",
    "notebook2script()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
